"""
Dijkstra优化算法
----
假定:
1)图任意两点可达
2)边权重(距离)大于0
"""
import unittest
from math import inf, isinf
from queue import PriorityQueue
import sys


def dijkstra(_N: int, _edges: dict, v: int):
    """
    :param _N: 点个数
    :param _edges: 边集合  ((起始点id, 结束点id), 距离), 点id取值范围为: {1, 2, ..., _N}
    :param v: 计算的点id
    :return: 点v到其他个点的最短距离, 点v到其他各点的中转集合
    点v到其他个点的最短距离: {点id: 最短距离}
    点v到其他各点的中转集合: 格式为{j: u}, 即v到j点需要经过u点, 用于获得最短路径, 是个可变变量
    """
    # 最短距离缓存, 是个可变变量
    dist = dict()
    # 路径中转, 格式: {j: u}, 即v到j点需要经过u点, 用于获得最短路径, 是个可变变量
    transit = {}
    # 初始化最短距离缓存
    for k in range(1, _N + 1):
        if (v, k) in _edges:
            dist[k] = _edges[(v, k)]
        else:
            dist[k] = inf

    # 已经找到最优路径的点集合, 可变
    _solved = {v}
    # 未找到最优路径的点集合, 可变
    _not_solved = {i for i in range(1, _N + 1) if i != v}

    _cnt = 0
    while len(_not_solved) > 0 and _cnt <= _N:
        # S1, 从_not_solved中找距离最近的点集, 更新点集
        u = -1
        dist_of_u = inf
        for j in _not_solved:
            if dist[j] <= dist_of_u:
                u = j
                dist_of_u = dist[j]
        _solved.add(u)
        _not_solved.remove(u)

        # S2, 松弛_un_resolved集合中的最短路径
        if not isinf(dist_of_u):
            for k, _d in dist.items():
                if (u, k) in _edges and k not in _solved:
                    if _edges[(u, k)] + dist_of_u < _d:
                        dist[k] = _edges[(u, k)] + dist_of_u
                        transit[k] = u
        _cnt += 1
    return dist, transit


# 1)_solved和_not_solved重复, 可以节省一个, 此外采用数组存, 空间相同读取效率要高过in判断
# 2)我们不会更新_solved中节点的dist, 因此u==j可以直接return
# 3)_not_solved可以使用堆排序, 我们可以使用
# 4)松弛时只松弛当前找到最优点的邻居节点(未solved的节点), 需要多建一个邻居节点的缓存

def dijkstra2(_N: int, _edges: dict, v: int, target: int):
    """
    :param target: 目标点id
    :param _N: 点个数
    :param _edges: 边集合  ((起始点id, 结束点id), 距离), 点id取值范围为: {1, 2, ..., _N}
    :param v: 计算的点id
    :return: 点v到其他个点的最短距离, 点v到其他各点的中转集合
    点v到其他个点的最短距离: {点id: 最短距离}
    点v到其他各点的中转集合: 格式为{j: u}, 即v到j点需要经过u点, 用于获得最短路径, 是个可变变量
    """
    # 最短距离缓存, 是个可变变量
    dist = [inf for _ in range(0, _N + 1)]
    # 路径中转, 格式: {j: u}, 即v到j点需要经过u点, 用于获得最短路径, 是个可变变量

    # 邻居节点缓存, 多建一个
    adjacency = [[] for _ in range(0, _N + 1)]
    for ((s, e), d) in _edges.items():
        adjacency[s].append((e, d))

    transit = {}
    # 初始化最短距离缓存
    dist[v] = 0

    solved = set()

    # 未找到最优路径的点集合, 采用堆排序的方式放入内存, 只放需要探索的点, 可变
    _not_solved = PriorityQueue()
    _not_solved.put((0, v))

    _cnt = 0
    while not _not_solved.empty() and _cnt <= _N:
        # S1, 从_not_solved中找距离最近的点集, 更新点集
        (dist_of_u, u) = _not_solved.get()

        # S2, 松弛_un_resolved集合中的最短路径
        if u == target:
            return dist_of_u, get_path_dijkstra(transit, v, u)

        solved.add(u)
        neighbors = adjacency[u]
        for (e, d) in neighbors:
            if e not in solved and dist[e] > dist_of_u + d:
                dist[e] = dist_of_u + d
                if u != v:
                    transit[e] = u
                _not_solved.put((dist_of_u + d, e))
        _cnt += 1
    return None, None


def get_path_dijkstra(_transit: dict, _i: int, _j: int) -> list:
    """
    获取图上的路径
    :param _transit: 中转矩阵 格式为{j: u}, 即v到j点需要经过u点, 用于获得最短路径, 是个可变变量, 如果j不在其中则表示直接可达
    :param _i: Dijkstra算法选定的起始点
    :param _j: 目标点
    :return: 路径
    """
    if _j in _transit:
        pre = get_path_dijkstra(_transit, _i, _transit[_j])
        return pre + [_j]
    else:
        return [_i, _j]


class TestDijkstra(unittest.TestCase):
    def test_init(self):
        self.input_and_output = [
            # N: 表示位置个数N: 2 <= N <= 10000
            # M: 道路条数M: 1 <= M <= 100000
            # S: 起点位置编号S: 1 <= S <= N
            # T: 快递位置编号T: 1 <= T <= N
            ('''
            6 9 1 4
            1 3 5
            1 4 30
            2 1 2
            2 5 8
            3 6 7
            3 2 15
            5 4 4
            6 4 10
            6 5 18''', 22, '1->3->6->4'),  # 22  1->3->6->4
            ('''
            6 9 1 5
            1 3 5
            1 4 30
            2 1 2
            2 5 8
            3 6 7
            3 2 15
            5 4 4
            6 4 10
            6 5 18''', 28, '1->3->2->5'),  # 28 1->3->2->5
            ('''
            6 9 1 6
            1 3 5
            1 4 30
            2 1 2
            2 5 8
            3 6 7
            3 2 15
            5 4 4
            6 4 10
            6 5 18''', 12, '1->3->6'),  # 12 1->3->6
            ('''
            4 4 1 3
            1 2 3
            2 4 3
            4 3 2
            3 1 1''', 8, '1->2->4->3'),
            ('''
            3 3 1 3
            1 2 3
            2 3 3
            3 1 1''', 6, '1->2->3'),
            ('''
            2 2 1 2
            1 2 3
            2 1 3''', 3, '1->2')
        ]

    def test_1dijkstra(self):
        self.test_init()
        for info, expected_distance, expected_trip in self.input_and_output:
            information = [each.strip() for each in info.splitlines() if len(each.strip()) > 0]
            _N, _M, _S, _T = map(int, information[0].split(' '))
            # 边
            edges = {}
            cnt = 0
            for line in information[1:]:
                s, e, d = map(int, line.split(' '))
                edges[(s, e)] = d
                cnt += 1
                if cnt >= _M:
                    break

            _dist, _transit = dijkstra(_N, edges, _S)
            distance = _dist[_T]
            trip = '->'.join([str(each) for each in get_path_dijkstra(_transit, _S, _T)])

            print('distance: %d, expected distance: %d, trip: %s, expected trip: %s' %
                  (distance, expected_distance, trip, expected_trip))
            assert distance == expected_distance
            assert trip == expected_trip

    def test_dijkstra2(self):
        self.test_init()
        print("dijkstra")
        for info, expected_distance, expected_trip in self.input_and_output:
            information = [each.strip() for each in info.splitlines() if len(each.strip()) > 0]
            _N, _M, _S, _T = map(int, information[0].split(' '))
            # 边
            edges = {}
            cnt = 0
            for line in information[1:]:
                s, e, d = map(int, line.split(' '))
                if ((s, e) in edges and edges[(s, e)] > d) or (s, e) not in edges:
                    if s == e:
                        continue
                    edges[(s, e)] = d
                cnt += 1
                if cnt >= _M:
                    break

            distance, trip_nodes = dijkstra2(_N, edges, _S, _T)
            trip = '->'.join(map(str, trip_nodes))
            print('distance: %d, expected distance: %d, trip: %s, expected trip: %s' %
                  (distance, expected_distance, trip, expected_trip))
            assert distance == expected_distance
            print(len(trip), len(expected_trip))
            assert trip == expected_trip


sys.setrecursionlimit(100000)  # 设置递归深度
TestDijkstra().test_dijkstra2()
